import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import sticky_price as sp
import sticky_wage as sw


'''Part 1: Figures and tables in paper.'''

# Color list of matplotlib's active property cycle; the plotting functions in
# this file index into `colors` so that all figures share one palette.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']


def _fig2_panel(MPE, sp_mult, sw_mult, title, GHH_text, sep_text, xlim, ylim, figsize, savepath=None):
    """Draw one multiplier-vs-MPE panel of fig2 and save it if `savepath` is given."""
    fig, ax = plt.subplots(1, figsize=figsize)
    ax.plot(MPE, sp_mult, label=r'sticky price, all', linewidth=2, marker='o', color=colors[0])
    # Shaded rectangle spanning MPE in [0, 0.04] and multipliers in [0.6, 2.0].
    ax.add_patch(patches.Rectangle((0, 0.6), 0.04, 1.4, alpha=0.5, color=colors[2]))
    ax.annotate(r'GHH', xy=GHH_text)
    ax.annotate(r'separable', xy=sep_text)
    # The sticky-wage model is drawn as a single marker at MPE = 0.
    ax.plot(0, sw_mult, label=r'sticky wage, sep', color=colors[3], marker='D', linestyle='')
    ax.set_title(title)
    ax.set_xlabel(r'Annual MPE')
    ax.set_ylabel('fiscal multiplier')
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.legend()
    ax.grid()
    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath, transparent=True)


def fig2(ss, T=200, rho_G=0.9, xlim=None, ylim=None, savefig=False, figsize=None, GHH_text_imp=None,
         sep_text_imp=None, GHH_text_cum=None, sep_text_cum=None, figname=None):
    """The tradeoff between fiscal multipliers and MPEs.

    Produces two figures (impact and cumulative multiplier against the annual
    MPE) for the sticky-price calibrations, with the sticky-wage model added
    as a single point at MPE = 0.

    Parameters
    ----------
    ss : mapping with keys 'sticky_price' and 'sticky_wage'. `ss['sticky_price']`
        is indexed by 0, 1, ..., len - 1; entry 0 is annotated as "separable"
        and entry 10 as "GHH" when default annotation positions are used.
    T : length of the shock path / truncation horizon.
    rho_G : persistence of the government spending shock `0.01 * rho_G ** t`.
    xlim, ylim : axis limits; default to [-0.01, 0.24] and [0.0, 6.5].
    savefig : if True, write both figures as PDFs under 'import_export/figures/'.
    figsize : figure size; defaults to (4, 3.5).
    GHH_text_imp, sep_text_imp : xy positions of the annotations in the impact panel.
    GHH_text_cum, sep_text_cum : xy positions of the annotations in the cumulative panel.
    figname : suffix appended to the saved file names (default: empty).

    Returns
    -------
    dict with keys 'sticky_wage' and 'sticky_price', each holding the 'impact'
    and 'cumulative' multipliers; 'sticky_price' additionally holds 'MPE'.
    """
    # Government spending shock
    periods = np.arange(T)
    dG = 0.01 * rho_G ** periods

    # Sticky-price model: one multiplier pair and one MPE per calibration
    sp_ss = ss['sticky_price']
    n_sp = len(sp_ss)
    sp_m0, sp_mT, MPE = np.empty(n_sp), np.empty(n_sp), np.empty(n_sp)
    for k in range(n_sp):
        sp_ir = get_ir(sp.sp_benchmark(sp_ss[k], T), dG)
        sp_df = (1 / (1 + sp_ss[k]['r'])) ** periods
        sp_m0[k] = sp_ir['Y'][0] / dG[0]
        # cumulative multiplier: ratio of discounted sums of dY and dG
        sp_mT[k] = np.sum(sp_ir['Y'] * sp_df) / np.sum(dG * sp_df)
        MPE[k] = sp.mpe_conversion(sp_ss[k])

    # Sticky-wage model
    sw_ss = ss['sticky_wage']
    sw_ir = get_ir(sw.sw_benchmark(sw_ss, T), dG)
    sw_df = (1 / (1 + sw_ss['r'])) ** periods
    sw_m0 = sw_ir['Y'][0] / dG[0]
    sw_mT = np.sum(sw_ir['Y'] * sw_df) / np.sum(dG * sw_df)

    # Plot settings
    figname = '' if figname is None else figname
    xlim = [-0.01, 0.24] if xlim is None else xlim
    ylim = [0.0, 6.5] if ylim is None else ylim
    figsize = (4, 3.5) if figsize is None else figsize

    if GHH_text_imp is None:
        GHH_text_imp = (0.01, sp_m0[10])
    if sep_text_imp is None:
        sep_text_imp = (MPE[0] - 0.05, sp_m0[0] - 0.5)
    # Defaults for the cumulative panel follow the cumulative multipliers that
    # are actually plotted there, not the impact multipliers.
    if GHH_text_cum is None:
        GHH_text_cum = (0.01, sp_mT[10])
    if sep_text_cum is None:
        sep_text_cum = (MPE[0] - 0.05, sp_mT[0] - 0.5)

    # Plot box
    _fig2_panel(MPE, sp_m0, sw_m0, r'Impact Multiplier', GHH_text_imp, sep_text_imp, xlim, ylim, figsize,
                savepath='import_export/figures/fig_box_impact' + figname + '.pdf' if savefig else None)
    _fig2_panel(MPE, sp_mT, sw_mT, r'Cumulative Multiplier', GHH_text_cum, sep_text_cum, xlim, ylim, figsize,
                savepath='import_export/figures/fig_box_cumulative' + figname + '.pdf' if savefig else None)

    # Report multipliers on plot
    return {
        'sticky_wage': {'impact': sw_m0, 'cumulative': sw_mT},
        'sticky_price': {'impact': sp_m0, 'cumulative': sp_mT, 'MPE': MPE},
    }


def table2(ss, T=200, rho_G=0.9, **kwargs):
    """Cumulative Fiscal Multipliers with Alternative Monetary and Fiscal Policies.

    Parameters
    ----------
    ss : mapping with keys 'sticky_price' (indexed by calibration; entries 0, 5
        and 10 are reported as 'separable', 'middle' and 'GHH') and 'sticky_wage'.
    T : length of the shock path / truncation horizon.
    rho_G : persistence of the government spending shock `0.01 * rho_G ** t`.
    **kwargs : forwarded to the sticky-wage model functions only.

    Returns
    -------
    (m0, mT) : DataFrames of impact and cumulative multipliers with rows
        ['separable', 'middle', 'GHH', 'sticky_wage'] and one column per policy
        ('taylor', 'benchmark', 'peg', 'low_rhoB', 'high_rhoB'). The rhoB
        columns are only computed for the sticky-wage row.

    Notes
    -----
    While the sticky-wage models are solved, `ss['sticky_wage']['rho_B']` is
    overridden (0 for 'low_rhoB', 0.95 for 'high_rhoB', 0.9 otherwise). The
    value the caller passed in is restored before returning, so later
    computations on the same steady state are not affected by this function.
    """
    # Gov't spending shock
    periods = np.arange(T)
    dG = 0.01 * rho_G ** periods

    def multipliers(jacobian, r):
        """Return (impact, cumulative) output multipliers for a Jacobian and interest rate r."""
        dY = get_ir(jacobian, dG)['Y']
        df = (1 / (1 + r)) ** periods
        return dY[0] / dG[0], np.sum(dY * df) / np.sum(dG * df)

    m0, mT = dict(), dict()

    # Sticky-wage model: (column name, model function, persistence of debt)
    sw_ss = ss['sticky_wage']
    sw_models = [
        ('taylor', sw.sw_taylor, 0.9),
        ('benchmark', sw.sw_benchmark, 0.9),
        ('peg', sw.sw_peg, 0.9),
        ('low_rhoB', sw.sw_benchmark, 0),
        ('high_rhoB', sw.sw_benchmark, 0.95),
    ]
    missing = object()
    try:
        rho_B_saved = sw_ss['rho_B']
    except KeyError:
        rho_B_saved = missing
    try:
        for model, solve, rho_B in sw_models:
            sw_ss['rho_B'] = rho_B
            impact, cumulative = multipliers(solve(sw_ss, T, **kwargs), sw_ss['r'])
            m0[model], mT[model] = {'sticky_wage': impact}, {'sticky_wage': cumulative}
    finally:
        # Undo the temporary override so the caller's steady state is unchanged.
        if rho_B_saved is missing:
            sw_ss.pop('rho_B', None)
        else:
            sw_ss['rho_B'] = rho_B_saved

    # Sticky-price model
    sp_ss = ss['sticky_price']
    sp_models = [('taylor', sp.sp_taylor), ('benchmark', sp.sp_benchmark), ('peg', sp.sp_peg)]
    for model, solve in sp_models:
        for k, variant in zip([0, 5, 10], ['separable', 'middle', 'GHH']):
            m0[model][variant], mT[model][variant] = multipliers(solve(sp_ss[k], T), sp_ss[k]['r'])

    # Formatting
    row_order = ['separable', 'middle', 'GHH', 'sticky_wage']
    m0 = pd.DataFrame.from_dict(m0).reindex(index=row_order)
    mT = pd.DataFrame.from_dict(mT).reindex(index=row_order)

    return m0, mT


def fig3(ss, T=200, rho_G=0.9, Tplot=61, savefig=False, figsize=None):
    """Government Spending and the Response of Private Consumption.

    Left panel: the government spending shock. Right panel: the consumption
    response of every sticky-price calibration and of the sticky-wage model.

    Parameters
    ----------
    ss : mapping with keys 'sticky_price' (a dict keyed by integers that are
        also used to index an 11-point transparency scale; key 0 is labelled
        "sep" and key 10 "GHH") and 'sticky_wage'.
    T : length of the shock path / truncation horizon.
    rho_G : persistence of the government spending shock `0.01 * rho_G ** t`.
    Tplot : number of periods shown; the x-axis limits follow this value.
    savefig : if True, write the figure to 'import_export/figures/fig_main_irf.pdf'.
    figsize : figure size; defaults to (7, 2.8).
    """
    # Gov't spending shock
    dG = 0.01 * rho_G ** np.arange(T)

    # Impulse responses: sticky-price calibrations first, then sticky wage
    sp_ss = ss['sticky_price']
    ir = {k: get_ir(sp.sp_benchmark(sp_ss[k], T), dG) for k in sp_ss}
    ir['sticky_wage'] = get_ir(sw.sw_benchmark(ss['sticky_wage'], T), dG)

    # Plot settings
    if figsize is None:
        figsize = (.7*10, .7*4)
    # Small margin around the plotted window 0..Tplot-1 ([-2, 62] for Tplot=61).
    xlim = [-2, Tplot + 1]

    # Plot G and C
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    ax = axes.flatten()
    ax[0].plot(100 * dG[:Tplot], color='black')
    ax[0].set_title(r'Government Spending')
    ax[0].set_xlabel(r'quarters')
    ax[0].set_ylabel(r'\% of steady state $Y$')
    ax[0].set_xlim(xlim)
    ax[0].axhline(0, linestyle=':', color='gray')

    # Only the two extreme calibrations get a legend entry; the others are
    # drawn unlabelled, distinguished by transparency.
    alpha = np.linspace(0.3, 1, 11)
    end_labels = {0: r'sticky price, sep', 10: r'sticky price, GHH'}
    for k in sp_ss:
        label_kw = {'label': end_labels[k]} if k in end_labels else {}
        ax[1].plot(100 * ir[k]['C'][:Tplot], color=colors[0], alpha=alpha[k], **label_kw)
    ax[1].plot(100 * ir['sticky_wage']['C'][:Tplot], label=r'sticky wage, sep', linewidth=2, color=colors[3])
    ax[1].axhline(0, linestyle=':', color='gray')
    ax[1].set_xlabel(r'quarters')
    ax[1].set_title(r'Consumption')
    ax[1].set_xlim(xlim)
    ax[1].legend()

    plt.tight_layout()
    if savefig:
        plt.savefig('import_export/figures/fig_main_irf.pdf', transparent=True)


def fig8(ss, T=200, rho_G=0.9, Tplot=61, savefig=False, figsize=None):
    """Additional Impulse Responses with Separable Preferences.

    Draws a 4x2 grid: six macro variables for the sticky-price (calibration 0)
    and sticky-wage models, followed by the government spending shock and the
    sticky-price response of government bonds.

    Parameters
    ----------
    ss : mapping with keys 'sticky_price' (entry 0 is used) and 'sticky_wage'.
    T : length of the shock path / truncation horizon.
    rho_G : persistence of the government spending shock `0.01 * rho_G ** t`.
    Tplot : number of periods shown; the x-axis limits follow this value.
    savefig : if True, write the figure to 'import_export/figures/fig_irf_macro.pdf'.
    figsize : figure size; defaults to (7, 9.8).
    """
    # Gov't spending shock
    dG = 0.01 * rho_G ** np.arange(T)

    # Impulse responses of both models
    ir = dict()
    ir['sticky_price'] = get_ir(sp.sp_benchmark(ss['sticky_price'][0], T), dG)
    ir['sticky_wage'] = get_ir(sw.sw_benchmark(ss['sticky_wage'], T), dG)

    # Plot settings
    if figsize is None:
        figsize = (.7*10, .7*14)
    # Small margin around the plotted window 0..Tplot-1 ([-2, 62] for Tplot=61).
    xlim = [-2, Tplot + 1]

    # Plot
    variables = ['Y', 'C', 'A', 'p', 'w', 'tax_all']
    labels = ['Output', 'Consumption', 'Assets', 'Equity Price', 'Pre-Tax Real Wage', 'Labor tax']
    fig, axes = plt.subplots(4, 2, figsize=figsize)
    ax = axes.flatten()
    for k, (v, l) in enumerate(zip(variables, labels)):
        # Each model is handled on its own, so a variable missing from one
        # model never hides or replaces the other model's actual response.
        if v in ir['sticky_price']:
            ax[k].plot(100 * ir['sticky_price'][v][:Tplot], label=r'sticky price', linewidth=2, color=colors[0])
        if v in ir['sticky_wage']:
            sw_path = 100 * ir['sticky_wage'][v][:Tplot]
        else:
            # variable not available in the sticky-wage responses: draw a flat line
            sw_path = np.zeros(Tplot)
        ax[k].plot(sw_path, label=r'sticky wage', linewidth=2, color=colors[3])
        ax[k].axhline(0, linestyle=':', color='gray')
        ax[k].set_title(l)
        ax[k].legend()
        ax[k].set_xlim(xlim)
        if k % 2 == 0:
            ax[k].set_ylabel(r'\% of steady state $Y$')

    ax[6].plot(100 * dG[:Tplot], color='black')
    ax[6].set_title(r'Government Spending')
    ax[6].set_xlabel(r'quarters')
    ax[6].set_ylabel(r'\% of steady state $Y$')
    ax[6].axhline(0, linestyle=':', color='gray')
    ax[6].set_xlim(xlim)

    ax[7].plot(100 * ir['sticky_price']['B'][:Tplot], color='black')
    ax[7].set_title(r'Government Bonds')
    ax[7].set_xlabel(r'quarters')
    ax[7].axhline(0, linestyle=':', color='gray')
    ax[7].set_xlim(xlim)

    plt.tight_layout()
    if savefig:
        # Forward slashes work on both POSIX and Windows; backslashes would
        # become part of the file name outside Windows.
        plt.savefig('import_export/figures/fig_irf_macro.pdf', transparent=True)


def fig9(ss, T=200, rho_G=0.9, Tplot=61, nS=25, name='sep', savefig=False, figsize=None):
    """Disaggregated Responses in Sticky-Price Model with Separable Preferences.

    Solves the sticky-price model nonlinearly for a small government spending
    shock and plots consumption and hours for the first `nS` state rows
    (labelled "impatient"), the remaining rows ("patient") and their average.

    Parameters
    ----------
    ss : steady state of a single sticky-price calibration; must provide the
        arrays 'D', 'c' and 'n'.
    T : length of the shock path / truncation horizon.
    rho_G : persistence of the government spending shock `0.0001 * rho_G ** t`.
    Tplot : number of periods shown; the x-axis limits follow this value.
    nS : number of leading state rows that form the first (impatient) group.
    name : suffix of the saved file name.
    savefig : if True, write the figure to
        'import_export/figures/fig_irf_micro_sp-<name>.pdf'.
    figsize : figure size; defaults to (7, 2.8).

    Returns
    -------
    The object returned by `sp.sp_benchmark(ss, T, linear=False, dG=dG)`.
    """
    # Gov't spending shock
    dG = 0.0001 * rho_G ** np.arange(T)

    # Sticky-price model
    ir = sp.sp_benchmark(ss, T, linear=False, dG=dG)

    def group_response(var, rows):
        """Twice the deviation from steady state of the D-weighted sum of `var` over `rows`."""
        # NOTE(review): the factor 2 presumably rescales by a group mass of 1/2 -- confirm.
        path = np.sum(ir['D'][:, rows, :] * ir[var][:, rows, :], axis=(1, 2))
        steady = np.sum(ss['D'][rows, :] * ss[var][rows, :])
        return 2 * (path - steady)

    # Consumption and labor response by patience
    impatient, patient = slice(None, nS), slice(nS, None)
    dC1, dC2 = group_response('c', impatient), group_response('c', patient)
    dN1, dN2 = group_response('n', impatient), group_response('n', patient)

    # Plot settings
    if figsize is None:
        figsize = (.7*10, .7*4)
    # Small margin around the plotted window 0..Tplot-1 ([-2, 62] for Tplot=61).
    xlim = [-2, Tplot + 1]

    fig, axes = plt.subplots(1, 2, figsize=figsize)
    ax = axes.flatten()
    for panel, title, first, second in [(ax[0], r'Consumption', dC1, dC2), (ax[1], r'Hours', dN1, dN2)]:
        panel.plot(10000 * (first[:Tplot] + second[:Tplot]) / 2, label=r'all', linewidth=2, color='black')
        panel.plot(10000 * first[:Tplot], label=r'impatient', linewidth=2, linestyle='--', color=colors[1])
        panel.plot(10000 * second[:Tplot], label=r'patient', linewidth=2, linestyle='-.', color=colors[2])
        panel.axhline(0, linestyle=':', color='gray')
        panel.set_xlim(xlim)
        panel.legend()
        panel.set_title(title)
        panel.set_xlabel(r'quarters')
    ax[0].set_ylabel(r'\% of steady state $Y$')

    plt.tight_layout()
    if savefig:
        # Forward slashes work on both POSIX and Windows; backslashes would
        # become part of the file name outside Windows.
        plt.savefig('import_export/figures/fig_irf_micro_sp-' + name + '.pdf', transparent=True)

    return ir


def fig10(ss, T=200, rho_G=0.9, Tplot=61, nS=25, savefig=False, figsize=None):
    """Disaggregated Responses in Sticky-Wage Model with Separable Preferences.

    Solves the sticky-wage model nonlinearly for a small government spending
    shock and plots consumption for the first `nS` state rows (labelled
    "impatient"), the remaining rows ("patient") and their average, next to
    the response of aggregate hours `L`.

    Parameters
    ----------
    ss : steady state of the sticky-wage model; must provide the arrays 'D'
        and 'c' and the value 'L'.
    T : length of the shock path / truncation horizon.
    rho_G : persistence of the government spending shock `0.0001 * rho_G ** t`.
    Tplot : number of periods shown; the x-axis limits follow this value.
    nS : number of leading state rows that form the first (impatient) group.
    savefig : if True, write the figure to 'import_export/figures/fig_irf_micro_sw.pdf'.
    figsize : figure size; defaults to (7, 2.8).

    Returns
    -------
    The object returned by `sw.sw_benchmark(ss, T, linear=False, dG=dG)`.
    """
    # Gov't spending shock
    dG = 0.0001 * rho_G ** np.arange(T)

    # Sticky-wage model
    ir = sw.sw_benchmark(ss, T, linear=False, dG=dG)

    def consumption_response(rows):
        """Twice the deviation from steady state of D-weighted consumption over `rows`."""
        # NOTE(review): the factor 2 presumably rescales by a group mass of 1/2 -- confirm.
        path = np.sum(ir['D'][:, rows, :] * ir['c'][:, rows, :], axis=(1, 2))
        steady = np.sum(ss['D'][rows, :] * ss['c'][rows, :])
        return 2 * (path - steady)

    # Consumption response by patience, aggregate labor response
    dC1 = consumption_response(slice(None, nS))
    dC2 = consumption_response(slice(nS, None))
    dN = ir['L'] - ss['L']

    # Plot settings
    if figsize is None:
        figsize = (.7*10, .7*4)
    # Small margin around the plotted window 0..Tplot-1 ([-2, 62] for Tplot=61).
    xlim = [-2, Tplot + 1]

    # Plot consumption
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    ax = axes.flatten()
    ax[0].plot(10000 * (dC1[:Tplot] + dC2[:Tplot]) / 2, label=r'all', linewidth=2, color='black')
    ax[0].plot(10000 * dC1[:Tplot], label=r'impatient', linewidth=2, linestyle='--', color=colors[1])
    ax[0].plot(10000 * dC2[:Tplot], label=r'patient', linewidth=2, linestyle='-.', color=colors[2])
    ax[0].set_title(r'Consumption')
    ax[0].set_ylabel(r'\% of steady state $Y$')

    # Plot labor
    ax[1].plot(10000 * dN[:Tplot], label=r'all', color='black', linewidth=2)
    ax[1].set_title(r'Hours')

    # Settings shared by both panels
    for panel in ax:
        panel.axhline(0, linestyle=':', color='gray')
        panel.legend()
        panel.set_xlim(xlim)
        panel.set_xlabel(r'quarters')

    plt.tight_layout()
    if savefig:
        # Forward slashes work on both POSIX and Windows; backslashes would
        # become part of the file name outside Windows.
        plt.savefig('import_export/figures/fig_irf_micro_sw.pdf', transparent=True)

    return ir


# Section marker (a bare string expression with no runtime effect): utilities
# shared by the figure and table functions above.
'''Part 2: helpers'''


def get_ir(M, dG):
    """Turn a shock path into impulse responses for every output that reacts to 'G'.

    `M` is a nested mapping `M[output][input]`. For each output that has a
    'G' entry the response is the product `M[output]['G'] @ dG`; outputs
    without a 'G' entry are left out of the result.
    """
    return {out: M[out]['G'] @ dG for out in M.keys() if 'G' in M[out].keys()}
